import re
from ParseResponse import parse_function_call_string
from dataclasses import dataclass
from typing import *
import random
import json
import sys


# Regex for a single call in an ODSL line: an identifier, optional whitespace,
# "(", then the shortest run of characters up to the first ")".
Pattern = r"\b\w+\s*\(.*?\)"
# Seeds the global `random` generator at import time (random is not otherwise
# used in the code visible here).
random.seed(42)


@dataclass
class Attribute:
    """Accumulated targets and formatting labels extracted from an ODSL program.

    Built by positional construction, e.g. ``Attribute([], [], [])``, and then
    updated instruction by instruction.
    """

    # Table names selected via select_tables.
    table_names: List[str]
    # Column names selected via select_columns.
    column_names: List[str]
    # Formatting labels (e.g. "fontColor", "fontStyle") from conditional_format_columns.
    attributes: List[str]


def get_correct_header(header_str):
    """Decode a JSON header and normalise it to a list of header rows.

    A flat header such as ``["a", "b"]`` is wrapped into ``[["a", "b"]]``;
    a header whose first element is already a list is returned as decoded.
    Falsy decoded values (``[]``, ``null`` -> ``None``, ``""``) are returned
    unchanged.
    """
    decoded = json.loads(header_str)
    # Wrap a single flat row so callers can always index header[0][col].
    if decoded and not isinstance(decoded[0], list):
        return [decoded]
    return decoded


def extract_from_instruction(instr: Dict, attribute: Attribute, contextTree: Dict):
    """Fold one parsed ODSL instruction into ``attribute``.

    Only three instructions contribute:

    * ``select_tables``: ``tableNames`` replaces ``attribute.table_names``;
      ``scope='selection'`` replaces it with the first active table's name.
    * ``select_columns``: ``columnNames`` or ``tableColumnNumbers`` replaces
      ``attribute.column_names``; ``scope='selection'`` adds the first cell
      of the first header row of the first active table.
    * ``conditional_format_columns``: records which formatting keywords were
      supplied, as labels in ``attribute.attributes``.

    Any other ``function_name`` leaves ``attribute`` unchanged.

    Args:
        instr: Parsed call with ``'function_name'`` and ``'keyword_arguments'``
            (keyword name -> value) entries.
        attribute: Accumulator; updated in place and returned.
        contextTree: Context dict; only ``contextTree['activeTables'][0]``
            (its ``'name'`` and JSON-encoded ``'headers'``) is read, and only
            when a ``scope`` or ``tableColumnNumbers`` argument needs it.

    Returns:
        The same ``attribute`` object.
    """
    function_name = instr['function_name']
    arguments = instr['keyword_arguments']
    is_selection = 'scope' in arguments and arguments['scope'] == 'selection'

    if function_name == 'select_tables':
        if 'tableNames' in arguments:
            attribute.table_names = arguments['tableNames']
        if is_selection:
            attribute.table_names = [contextTree['activeTables'][0]['name']]
    elif function_name == 'select_columns':
        if 'columnNames' in arguments:
            attribute.column_names = arguments['columnNames']
        if 'tableColumnNumbers' in arguments:
            indexes = arguments['tableColumnNumbers']
            # The value may arrive JSON-encoded (e.g. "[1, 3]") or already parsed.
            if isinstance(indexes, (str, bytes, bytearray)):
                indexes = json.loads(indexes)
            header = get_correct_header(contextTree['activeTables'][0]['headers'])
            # Column numbers are treated as 1-based positions in the first header row.
            attribute.column_names = [header[0][int(i) - 1] for i in indexes]
        if is_selection:
            header = get_correct_header(contextTree['activeTables'][0]['headers'])
            # Build a new list instead of append(): column_names may be the very
            # list stored in `arguments` (assigned above), and append() would
            # silently mutate the caller's parsed instruction.
            attribute.column_names = attribute.column_names + [header[0][0]]
    elif function_name == 'conditional_format_columns':
        if 'fontColor' in arguments:
            attribute.attributes.append("fontColor")
        # NOTE(review): recorded as "fillSize"; confirm this is not meant to be "fillColor".
        if 'fillSize' in arguments:
            attribute.attributes.append("fillSize")
        # The generic "size" keyword is recorded as the "fontSize" attribute.
        if 'size' in arguments:
            attribute.attributes.append("fontSize")
        # Any emphasis flag collapses into a single "fontStyle" label.
        if any(arg in ('bold', 'italic', 'underline', 'strikethrough') for arg in arguments):
            attribute.attributes.append("fontStyle")
    return attribute


def parse_instruction(odsl: str):
    """Parse the first function call found in a single ODSL line.

    Args:
        odsl: One line of ODSL text.

    Returns:
        The result of ``parse_function_call_string`` for the matched call text;
        downstream code reads its ``'function_name'`` and
        ``'keyword_arguments'`` entries.

    Raises:
        ValueError: if the line contains no ``name(...)`` call.
    """
    # re.search yields the same leftmost match as re.findall(...)[0] without
    # scanning the rest of the line.
    match = re.search(Pattern, odsl)
    if match is None:
        raise ValueError(f"No function call found in ODSL line: {odsl!r}")
    # NOTE(review): Pattern ends in a non-greedy ".*?\)", so the match stops at
    # the first ")" -- an argument that itself contains ")" (e.g. a formula) is
    # truncated before parsing. Confirm whether greedy matching is wanted.
    return parse_function_call_string(match.group(0))


def get_attributes(odsl, context_tree) -> Attribute:
    """Collect table, column and formatting attributes from an ODSL program.

    Args:
        odsl: ODSL program text with one instruction per line. Blank or
            whitespace-only lines (e.g. from a trailing newline) are skipped.
        context_tree: Context dict forwarded to ``extract_from_instruction``.

    Returns:
        An ``Attribute`` accumulated over all instructions, in order.
    """
    # Blank lines hold no call; parsing them would fail, so drop them up front.
    lines = [line for line in odsl.split('\n') if line.strip()]
    # Parse every line before extracting, so a malformed line is reported
    # before any context lookups happen.
    parsed_odsl = [parse_instruction(line) for line in lines]
    attribute = Attribute([], [], [])
    for parsed in parsed_odsl:
        attribute = extract_from_instruction(parsed, attribute, context_tree)
    return attribute


def main(path='SampleSingleTurnBenchmark.json'):
    """Print the attributes of the first acceptable response in a benchmark file.

    Args:
        path: JSON benchmark file. Its ``'acceptableResponses'[0]`` is used as
            the ODSL text and its ``'contextTree'`` is a JSON-encoded string.
    """
    # JSON is UTF-8 by specification; don't depend on the platform default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        benchmark = json.load(f)
    response = benchmark['acceptableResponses'][0]
    context_tree = json.loads(benchmark['contextTree'])
    attribute = get_attributes(response, context_tree)
    print(attribute)


if __name__ == "__main__":
    main()